# ------------------------------------------------------------------------------
# Generates predicted values from each fitted ensemble model
# From: stress_test_medicare repo <https://gitlab.com/labsysmed/zolab-projects/stress_test_medicare/-/blob/master/code/03_analysis/01_build-models/07_predict-ensemble.R>
# Updated by: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=25000]" bash 07_predict_ensemble.sh {dataset} {split} {restriction}
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Pin the RNG state up front so repeated runs start from the same state.
set.seed(seed = 1L)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(Matrix)
library(glmnet)
library(optparse)
library(here)
library(testit) # assert()

u <- modules::use(here("lib", "util.R"))

# Command Line Args ------------------------------------------------------------
# All three options are required:
#   --dataset      key under `modeling` in lib/filepaths.yml; also used in the
#                  subscores input and scores output file names
#   --split        subdirectory level for subscores/models/prediction
#   --restriction  subdirectory level; "all" means no suffix on score columns
arg_config <- list(
  make_option("--dataset", type = "character"),
  make_option("--split", type = "character"),
  make_option("--restriction", type = "character")
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)

# An option that was not supplied shows up here as NULL (no default is set).
# Check explicitly. Otherwise file.path() silently collapses to character(0)
# and the script fails much later with an opaque error.
required_args <- c("dataset", "split", "restriction")
is_missing <- vapply(
  required_args,
  function(arg) {
    value <- arg_list[[arg]]
    is.null(value) || length(value) != 1L || is.na(value) || !nzchar(value)
  },
  logical(1)
)
if (any(is_missing)) {
  stop(
    "Missing required argument(s): ",
    paste0("--", required_args[is_missing], collapse = ", "),
    call. = FALSE
  )
}

# Helpers ----------------------------------------------------------------------

# Build the output column name for one saved model file.
#
# The trailing ".rds" extension is removed from `model_file` and the result is
# prefixed with "p__". When `restriction` is anything other than "all", a
# "__{restriction}" suffix is appended, e.g.
#   get_score_name("ensemble__x.rds", "all")  -> "p__ensemble__x"
#   get_score_name("ensemble__x.rds", "male") -> "p__ensemble__x__male"
#
# @param model_file Model file name(s) (not full paths), expected to end in
#   ".rds".
# @param restriction Single restriction label. Defaults to the --restriction
#   command-line argument.
# @return Character vector of column names, one per `model_file`.
get_score_name <- function(model_file, restriction = arg_list$restriction) {
  stopifnot(is.character(restriction), length(restriction) == 1L)
  # Escape the dot and anchor to the end. The bare pattern ".rds" matches any
  # character followed by "rds" anywhere in the name (e.g. "records.rds").
  model_name <- sub("\\.rds$", "", model_file)
  restriction_lab <- if (restriction == "all") {
    ""
  } else {
    paste0("__", restriction)
  }
  paste0("p__", model_name, restriction_lab)
}

# Directories ------------------------------------------------------------------
message("Establishing directories...")
paths <- read_yaml(here("lib", "filepaths.yml"))

# All three directories share the layout
#   {paths$modeling$dir}/{kind}/{split}/{restriction}
subscores_dir <- file.path(
  paths$modeling$dir, "subscores", arg_list$split, arg_list$restriction
)
models_dir <- file.path(
  paths$modeling$dir, "models", arg_list$split, arg_list$restriction
)
save_dir <- file.path(
  paths$modeling$dir, "prediction", arg_list$split, arg_list$restriction
)

# Fail fast. If models_dir is missing, list.files() just returns nothing, and
# the script would save a file that has IDs but no prediction columns.
assert("Subscores directory exists", dir.exists(subscores_dir))
assert("Models directory exists", dir.exists(models_dir))
assert("Predictions directory exists", dir.exists(save_dir))

# Load Data --------------------------------------------------------------------
# `ids` keeps only the ed_enc_id column; prediction columns are added to it
# later. `subscores` is the model input for the same dataset.
message("Loading data...")
dataset_path_template <- paths$modeling[[arg_list$dataset]]
# A --dataset with no entry in filepaths.yml gives NULL here. Without this
# check, it fails inside readRDS() with an unhelpful error.
assert(
  "Dataset path is configured in filepaths.yml",
  !is.null(dataset_path_template)
)
# NOTE(review): the configured path goes through glue(), so it may interpolate
# variables in scope here. Confirm which ones the YAML entry expects.
ids <- readRDS(glue(dataset_path_template)) %>%
  select(ed_enc_id) %>%
  setDT()
subscores <- readRDS(
  file.path(subscores_dir, glue("subscores_{arg_list$dataset}_set.rds"))
) %>%
  setDT()

# Predictions are attached to `ids` by position, so the row counts must agree.
# NOTE(review): this does not verify the two tables share the same row order.
assert("IDs and subscores same length", nrow(ids) == nrow(subscores))

# Locate Models ----------------------------------------------------------------
# Ensemble models are files in models_dir whose name contains "ensemble__".
# Any file whose name contains "ecg" is excluded.
# NOTE(review): presumably ECG-based ensembles are scored elsewhere; confirm.
message("Locating ensemble models...")
model_files <- list.files(models_dir)
ecg_files <- str_subset(model_files, "ecg")
ensemble_files <- str_subset(model_files, "ensemble__") %>%
  setdiff(ecg_files)
# With no matches, the prediction step would be a no-op and the saved file
# would quietly contain only IDs.
assert("At least one ensemble model found", length(ensemble_files) > 0)

# Predict GLM ------------------------------------------------------------------
# For each ensemble model file, predict on `subscores` with type = "response".
# The result is added to `ids` by reference (data.table `:=`) in a column
# named by get_score_name().
# NOTE(review): `subscores` is passed positionally as the new-data argument,
# which suits predict.glm(). Confirm the saved models are glm fits.
# NOTE(review): predictions are aligned to `ids` purely by row position.
message("Predicting with GLM...")
walk(ensemble_files, function(model_file) {
  fitted_model <- readRDS(file.path(models_dir, model_file))
  col_name <- get_score_name(model_file)
  preds <- predict(fitted_model, subscores, type = "response")
  ids[, (col_name) := preds]
})

# Save Data --------------------------------------------------------------------
# One output file per dataset: ed_enc_id plus one score column per model.
save_fp <- file.path(
  save_dir,
  glue("scores_{arg_list$dataset}_set.rds")
)
message("Saving results to ", save_fp, "...")
write_rds(ids, save_fp)

# Done -------------------------------------------------------------------------
message("Done.")
